/**
  * Created by ariellemoro on 22.11.16.
  */

import collection.mutable.{ListBuffer, Map}
import scala.math.{abs, min, sqrt, pow}

object KDTree {

  /**
    * Links of the k-d tree(s) built by [[buildKDTree]]: one single-entry map per node, mapping
    * the node's point to `ListBuffer(leftChild, rightChild)`, where an empty `ListBuffer` means
    * "no child". Nodes are identified by their coordinates (value equality of `ListBuffer`).
    *
    * NOTE: this is global state; [[buildKDTree]] only appends to it and never clears it, so call
    * `kdgraph.clear()` before building an unrelated tree.
    */
  var kdgraph = new ListBuffer[Map[ListBuffer[Int], ListBuffer[ListBuffer[Int]]]]

  /** Demo: builds a 2-d tree over six sample points, prints its links and runs one query. */
  def main(args: Array[String]): Unit = {
    // Point list : (2,3), (5,4), (9,6), (4,7), (8,1), (7,2)
    val points = ListBuffer(
      ListBuffer(2, 3), ListBuffer(5, 4), ListBuffer(9, 6),
      ListBuffer(4, 7), ListBuffer(8, 1), ListBuffer(7, 2)
    )
    val depth = 0
    // Print the KDTree
    val firstPoint = buildKDTree(points, depth)
    println(s"KDTree links (starting point: $firstPoint) =>")
    kdgraph.foreach(println)
    // Search test
    val searchPoint = ListBuffer(9, 4)
    val nearestPoint = nearestNeighbourSearch(searchPoint, firstPoint)
    println(s"The nearest point on $searchPoint is $nearestPoint")
  }

  /**
    * Builds a k-d tree over `points` and appends its links to [[kdgraph]].
    *
    * The splitting axis of a node is `depth % k`, where `k` is the points' dimension. The node is
    * the median point along that axis. Points whose coordinate on the axis is strictly smaller
    * go to the left subtree; all others go to the right subtree, including points that tie with
    * the median on that axis, so no point is lost. Identical points are collapsed into one,
    * because nodes are looked up by value. Links are appended in post-order: left subtree, then
    * right subtree, then this node.
    *
    * @param points points to insert; all must have the same, non-zero number of coordinates
    * @param depth  depth of the subtree being built (0 for a whole tree)
    * @return the root point of the (sub)tree, or an empty `ListBuffer` when `points` is empty
    * @throws IllegalArgumentException if the points are 0-dimensional or of mixed dimension
    */
  def buildKDTree(points: ListBuffer[ListBuffer[Int]], depth: Int): ListBuffer[Int] = {
    if (points.isEmpty) new ListBuffer[Int]
    else {
      val k = points.head.length
      require(k > 0, "points must have at least one coordinate")
      require(points.forall(_.length == k), "all points must have the same dimension")
      val axis = depth % k
      // sortWith returns a new (stable-sorted) buffer, so the caller's `points` is never mutated.
      val sortedPoints = points.distinct.sortWith((p1, p2) => p1(axis) < p2(axis))
      val medianPoint = sortedPoints.remove(sortedPoints.length / 2)
      // Ties with the median on this axis go right; the search below relies on this invariant.
      val (leftPoints, rightPoints) = sortedPoints.partition(p => p(axis) < medianPoint(axis))
      val leftPoint = buildKDTree(leftPoints, depth + 1)
      val rightPoint = buildKDTree(rightPoints, depth + 1)
      kdgraph += Map(medianPoint -> ListBuffer(leftPoint, rightPoint))
      medianPoint
    }
  }

  /**
    * Finds the point of the tree stored in [[kdgraph]] that is closest (Euclidean) to `searchPoint`.
    *
    * At each node the child on the query's side of the splitting plane is searched first. The
    * other child is searched only if the plane is closer than the best distance found so far, so
    * closer points deeper in the tree are not missed.
    *
    * @param searchPoint   query point; must have the same dimension as the tree's points
    * @param startingPoint root of the (sub)tree to search, as returned by [[buildKDTree]]
    * @param depth         depth of `startingPoint` in its tree, which fixes its splitting axis
    *                      (0 for the root returned by `buildKDTree(points, 0)`)
    * @return the closest point (on ties, the first one visited), or an empty `ListBuffer` when
    *         `startingPoint` is empty (an empty tree)
    */
  def nearestNeighbourSearch(searchPoint: ListBuffer[Int],
                             startingPoint: ListBuffer[Int],
                             depth: Int = 0): ListBuffer[Int] = {
    if (startingPoint.isEmpty) new ListBuffer[Int]
    else {
      val k = searchPoint.length
      // Index the links once, so children are found in O(1) instead of rescanning kdgraph per node.
      val children = kdgraph.iterator.flatMap(_.iterator).toMap
      // Guards against cycles if kdgraph accumulated links from several overlapping trees.
      val visited = scala.collection.mutable.HashSet.empty[ListBuffer[Int]]
      var best = startingPoint
      var bestScore = compareTwoPoints(searchPoint, startingPoint)

      def visit(node: ListBuffer[Int], nodeDepth: Int): Unit =
        if (node.nonEmpty && visited.add(node)) {
          val score = compareTwoPoints(searchPoint, node)
          if (score < bestScore) {
            best = node
            bestScore = score
          }
          children.get(node).foreach { next =>
            val left = next.lift(0).getOrElse(new ListBuffer[Int])
            val right = next.lift(1).getOrElse(new ListBuffer[Int])
            val axis = nodeDepth % k
            // Signed offset from the splitting plane, in Double to avoid Int overflow.
            val diff = searchPoint(axis).toDouble - node(axis)
            val (near, far) = if (diff < 0) (left, right) else (right, left)
            visit(near, nodeDepth + 1)
            // Every point on the far side is at least |diff| away, so it can only win if |diff| < best.
            if (abs(diff) < bestScore) visit(far, nodeDepth + 1)
          }
        }

      visit(startingPoint, depth)
      best
    }
  }

  /**
    * Euclidean distance between two points, using every coordinate.
    *
    * @param searchPoint first point
    * @param point       second point; must have the same number of coordinates as `searchPoint`
    * @return the distance, computed in Double (no Int overflow on the differences)
    * @throws IllegalArgumentException if the points have different dimensions
    */
  def compareTwoPoints(searchPoint: ListBuffer[Int], point: ListBuffer[Int]): Double = {
    require(searchPoint.length == point.length,
      s"dimension mismatch: ${searchPoint.length} vs ${point.length}")
    val squaredDistance = searchPoint.iterator.zip(point.iterator)
      .map { case (a, b) => pow(a.toDouble - b, 2) }
      .sum
    sqrt(squaredDistance)
  }

}